import os
import os.path as osp
from typing import Tuple

import yaml
import gym
import numpy as np
from tqdm import tqdm

from diffgro.kinematic_llm import plan_waypoints, inflate_policy
from diffgro.utils import Parser, print_r, print_y, print_b
from diffgro.common.evaluations import evaluate, evaluate_complex
from train import eval_save, make_context


def make_env(args):
    """Build the environment named by ``args.env_name`` and pin its variant space.

    ``args.env_name`` has the form ``"<domain>.<task>"``.

    * ``metaworld``: ``task`` is passed straight to ``gym.make``.
    * ``metaworld_complex``: ``task`` is split on ``-`` and its last token is
      dropped. If the remaining tokens end in ``"variant"``, that token is
      removed as well and ``complex-variant-v2`` is built. Otherwise
      ``complex-v2`` is built. The remaining tokens are the ``task_list``.

    Afterwards every listed variant factor is replaced by a single-choice
    ``Categorical``, or a ``VariantSpace`` of them:

    * goal resistance: ``args.goal_resistance`` for ``metaworld``, and ``0``
      per object key for ``metaworld_complex``;
    * arm speed, wind x speed and wind y speed: ``1.0`` each.

    Args:
        args: Parsed CLI namespace. Reads ``env_name`` and ``seed``, plus
            ``goal_resistance`` for the ``metaworld`` domain.

    Returns:
        Tuple ``(env, domain_name, task_name)``.

    Raises:
        NotImplementedError: If the domain is neither ``metaworld`` nor
            ``metaworld_complex``.
    """
    print_r(f"<< Making Environment for {args.env_name}... >>")
    domain_name, task_name = args.env_name.split(".")
    if domain_name == "metaworld":
        env = gym.make(task_name, seed=args.seed)
    elif domain_name == "metaworld_complex":
        # Drop the trailing token (e.g. "v2"); what remains is the ordered list
        # of sub-tasks, optionally terminated by the "variant" marker.
        task_list = task_name.split("-")[:-1]
        if task_list[-1] == "variant":
            env = gym.make(
                "complex-variant-v2", seed=args.seed, task_list=task_list[:-1]
            )
        else:
            env = gym.make("complex-v2", seed=args.seed, task_list=task_list)
    else:
        # Reject unknown domains before ``env`` is used. Previously this path
        # crashed with UnboundLocalError at the print below instead.
        raise NotImplementedError(f"Unsupported domain: {domain_name!r}")
    print_y(
        f"Obs Space: {env.observation_space.shape}, Act Space: {env.action_space.shape}"
    )

    from diffgro.environments.variant import Categorical, VariantSpace

    if domain_name == "metaworld":
        print_r(f"Goal Resistance: {args.goal_resistance}")
        env.variant_space.variant_config["goal_resistance"] = Categorical(
            a=[args.goal_resistance]
        )
    else:
        # metaworld_complex (all other domains were rejected above): pin the
        # goal resistance of each listed object to the single value 0.
        env.variant_space.variant_config["goal_resistance"] = VariantSpace(
            {
                "handle": Categorical(a=[0]),
                "button": Categorical(a=[0]),
                "drawer": Categorical(a=[0]),
                "lever": Categorical(a=[0]),
                "door": Categorical(a=[0]),
            }
        )

    env.variant_space.variant_config["arm_speed"] = Categorical(a=[1.0])
    env.variant_space.variant_config["wind_xspeed"] = Categorical(a=[1.0])
    env.variant_space.variant_config["wind_yspeed"] = Categorical(a=[1.0])

    return env, domain_name, task_name


class KinematicPromptWrapper:
    """Policy-like adapter that plans waypoints lazily and then follows them.

    ``test`` passes instances of this class as the model argument to
    ``evaluate`` / ``evaluate_complex``. Presumably those call ``predict`` and
    ``reset``; confirm against ``diffgro.common.evaluations``.

    On the first ``predict`` after construction or ``reset``, waypoints are
    planned with ``plan_waypoints`` from the current observation and the
    optional ``context``, then turned into a policy by ``inflate_policy``.
    Later ``predict`` calls reuse that policy.
    """

    def __init__(self, env, context=None, action_scale=0.6):
        """
        Args:
            env: Environment to plan for. Must expose ``domain_name`` and
                ``task_name`` attributes, which ``test`` sets before wrapping.
            context: Optional context forwarded to ``plan_waypoints``.
            action_scale: Multiplier applied to every action produced by the
                inflated policy. Defaults to the previously hard-coded 0.6.
        """
        self.env = env
        self.context = context
        self.action_scale = action_scale

        # Policy inflated from planned waypoints; None means "replan on next predict".
        self.model = None

    def predict(self, obs, deterministic):
        """Return ``(action, None, {"guided": False})`` for ``obs``.

        ``deterministic`` is accepted for call-signature compatibility and is
        not used.
        """
        if self.model is None:
            waypoints = plan_waypoints(
                self.env.domain_name, self.env.task_name, self.env, obs, self.context
            )
            self.model = inflate_policy(waypoints)

        action = self.model.get_action(obs) * self.action_scale
        return action, None, {"guided": False}

    def reset(self):
        """Discard the current plan so the next ``predict`` replans."""
        self.model = None


def eval_episode(domain_name, task_name, env, video, context=None):
    """Roll out one episode with an LLM-planned kinematic policy.

    Waypoints are planned once from the initial observation with
    ``plan_waypoints``, and a policy is built from them with
    ``inflate_policy``. The policy is then stepped until the environment
    reports ``done`` or ``env.max_steps`` steps have run.

    Args:
        domain_name: Domain identifier forwarded to ``plan_waypoints``.
        task_name: Task identifier forwarded to ``plan_waypoints``.
        env: Environment exposing ``reset``, ``step``, ``render`` and
            ``max_steps``.
        video: If truthy, ``env.render()`` is called after every step and
            the results are collected.
        context: Optional context forwarded to ``plan_waypoints``.

    Returns:
        Tuple ``(success, steps, frames, info)`` with these fields:

        * ``success``: the last step's ``e_info["success"]``, or ``False`` if
          no step was taken.
        * ``steps``: the number of environment steps executed.
        * ``frames``: the rendered frames (empty unless ``video``).
        * ``info``: maps ``speed``/``force``/``force_axis``/``energy`` to
          their per-step values, and ``actions`` to the first three
          components of each action.
    """
    # Per-step metrics copied verbatim from the env's step info.
    env_metrics = ("speed", "force", "force_axis", "energy")

    frames = []
    obs = env.reset()
    info = {key: [] for key in (*env_metrics, "actions")}

    waypoints = plan_waypoints(domain_name, task_name, env, obs, context)
    policy = inflate_policy(waypoints)

    # Stays None when env.max_steps <= 0. Previously that case crashed with
    # UnboundLocalError on the return line.
    e_info = None
    step = 0
    for step in range(1, env.max_steps + 1):
        action = policy.get_action(obs)
        obs, _, done, e_info = env.step(action)

        for key in env_metrics:
            info[key].append(e_info[key])
        info["actions"].append(action[:3])

        if video:
            frames.append(env.render())

        if done:
            break

    success = e_info["success"] if e_info is not None else False
    return success, step, frames, info


def test(args):
    """Evaluate the kinematic LLM planner on one task or on a domain's suite.

    ``args.env_name`` is ``"<domain>.<task>"``. With ``task == "all"``, every
    task in the domain's built-in suite is evaluated, and the pooled successes
    are written with ``eval_save``. Each task is evaluated once per context:
    ``[None]`` by default, or, with ``args.context``, the first 4 (metaworld)
    or 2 (other domains) entries of ``config/contexts/<domain>/<task>.yml``.

    Note: ``args.env_name`` is overwritten with each task in turn, because
    ``make_env`` reads the task from it.

    Args:
        args: Parsed CLI namespace. Reads ``env_name``, ``tag``,
            ``n_episodes``, ``context`` and ``video``, plus whatever
            ``make_env`` reads.

    Raises:
        NotImplementedError: If ``task == "all"`` but the domain has no suite.
    """
    domain_name, task_name = args.env_name.split(".")
    save_path = osp.join("./results/kinematic_llm", domain_name, task_name, args.tag)

    if task_name != "all":
        task_list = [task_name]
    else:
        task_lists = {
            "metaworld": [
                "door-open-variant-v2",
                "drawer-open-variant-v2",
                "drawer-close-variant-v2",
                "window-open-variant-v2",
                "window-close-variant-v2",
                "button-press-variant-v2",
                "peg-insert-side-variant-v2",
                "push-variant-v2",
                "faucet-open-variant-v2",
                "pick-place-variant-v2",
            ],
            "metaworld_complex": [
                "button-drawer-puck-stick-variant-v2",
                "drawer-puck-stick-button-variant-v2",
                "puck-drawer-button-stick-variant-v2",
            ],
        }
        if domain_name not in task_lists:
            # Previously this silently evaluated nothing.
            raise NotImplementedError(
                f"No 'all' task suite defined for domain {domain_name!r}"
            )
        task_list = task_lists[domain_name]

    # metaworld uses the single-task evaluator; other domains use the complex one.
    evaluate_fn = evaluate if domain_name == "metaworld" else evaluate_complex
    max_contexts = 4 if domain_name == "metaworld" else 2

    tot_success = []
    num_episodes = args.n_episodes
    for task in task_list:
        # make_env reads the task from args.env_name, so point it at this task.
        args.env_name = f"{domain_name}.{task}"
        env, domain_name, _ = make_env(args)
        env.domain_name = domain_name
        env.task_name = task

        try:
            contexts = [None]
            if args.context:
                context_path = osp.join(
                    "config", "contexts", domain_name, f"{task}.yml"
                )
                with open(context_path, encoding="utf-8") as f:
                    contexts = yaml.load(f, Loader=yaml.FullLoader)[task]["text"]
                contexts = contexts[:max_contexts]

            for context in contexts:
                success = evaluate_fn(
                    KinematicPromptWrapper(env, context),
                    env,
                    domain_name,
                    task,
                    num_episodes,
                    True,
                    args.video,
                    save_path,
                    context=context,
                )
                tot_success.extend(success)
        finally:
            # Release each task's environment before building the next one.
            env.close()

    if len(task_list) > 1:
        eval_save(tot_success, save_path)


if __name__ == "__main__":
    # Script entry point: parse CLI options with the project's "train" parser
    # and run the evaluation.
    test(Parser("train").parse_args())
